Implement copy-on write

背景

xv6 使用 fork() 系统调用创建子进程时,需要将父进程的地址空间进行 深拷贝 ,即将页表和实际物理空间同时进行拷贝,以实现父进程和子进程地址空间的独立性。但很多时候,如 shell 程序,fork() 通常与 exec() 搭配使用,首先使用 fork() 创建子进程,随后在子进程中使用 exec() 将指定的程序加载到当前地址空间,这样在 fork() 中进行的地址空间拷贝就白白浪费了。

本实现要求实现一个写时复制(copy-on write)的 fork() 系统调用。具体来说,在进行虚拟内存拷贝时,不直接进行物理内存的拷贝,只是将父进程的页表复制给子进程,这样子进程和父进程的每个虚拟页面都指向了同一个物理页面,当子进程需要对某个虚拟页面进行写入时,为了保证父进程和子进程之间的独立性,子进程此时将进行物理内存的分配和拷贝,再进行写入。

实现方案

根据提示,可以将上述的写时复制的思路用 异常 的方式来实现。

首先可以利用页表项的 flags 中的 RSW 位来表示页表项是否为 COW 页,以便后续的异常处理。

修改 uvmcopy() ,将物理页面的分配操作去除,只是进行页表的拷贝,并将父进程和子进程的对应页表项的 PTE_W 置 0(以便在对 COW 页进行写入时陷入内核)、PTE_COW 置 1。

修改 usertrap(),当陷入内核时,内核通过查看 scause 寄存器(见下图)以及页表项的 PTE_W 和 PTE_COW 位,识别到陷入原因是发生在 COW 页上的 store page fault(寄存器值为 15)时,进行对应的异常处理:使用 kalloc() 为其分配物理页面,并将其页表项指向的物理地址数据拷贝到新分配的物理地址下,实现物理内存的拷贝。此时由于页表映射发生了改变,需要插入新的页表项,并删除旧的页表项。在处理了 COW 异常之后,该页面将不再是一个 COW 页,因此需要将 PTE_W 置 1、PTE_COW 置 0。

为了后续实现的方便,可以将 COW 页的判断和 COW 页的异常处理分别封装为两个函数:

int iscowpage(pagetable_t pgtbl, uint64 va) {
    if (va >= MAXVA) return 0;
    pte_t *pte = walk(pgtbl, va, 0);
    if (pte == 0) return 0;
    if ((*pte & PTE_V) == 0) return 0;
    if ((*pte & PTE_U) == 0) return 0;
    return *pte & PTE_COW;
}

int cowfault(pagetable_t pagetable, uint64 va) {
    uint64 va0 = PGROUNDDOWN(va);
    pte_t* pte;
    if((pte = walk(pagetable, va0, 0)) == 0) return -1;

    uint64 flags = PTE_FLAGS(*pte);
    uint64 pa0 = PTE2PA(*pte);

    flags &= (~PTE_COW); // clear COW bit
    flags |= PTE_W;      // set write bit

    uint64 mem;
    if ((mem = (uint64)kalloc()) == 0) return -1;
    memmove((void *)mem, (void *)pa0, PGSIZE);

    // remove old PTE
    uvmunmap(pagetable, va0, 1, 1);

    // install new PTE
    if(mappages(pagetable, va0, PGSIZE, mem, flags) < 0){
        kfree((void *)mem);
        return -1;
    }
    return 0;
}

此外,还需要为每个物理页面引入 引用计数(reference count) ,页面创建时计数为 1,每次添加或移除指向该物理地址的页表项都增加或减少引用计数,当引用计数为 0 时释放该物理页面。这里有一个实现的技巧:将引用计数的减少放到 kfree() 中,在 kfree() 中根据引用计数的大小决定是否释放物理页面。

最后,也是很容易忽视的一点,修改 copyout() 以实现对 COW 页的支持。刚开始看到这个提示的时候我很疑惑,前面的工作貌似已经足够实现 COW 了,为什么还要修改 copyout?原来 xv6 对 COW 页进行写时复制都是基于 store page fault,即当尝试写入一个 PTE_W 为 0 的页面时触发异常,导致陷入内核,再由内核进行 COW 页面的异常处理,其中陷入内核的操作是由硬件自动来完成的,具体来说,是在虚实地址转换阶段由 MMU 来完成的。而 copyout() 是运行在内核态下的函数,其地址转换是由内核中的函数 walk() 来实现的,因而不会自动触发异常并交由异常处理程序来处理,而需要手动来完成。由于前面已经将 COW 页的判断和处理封装成了函数,因此对 copyout() 的修改很简单:

if (iscowpage(pagetable, va0)) {
    cowfault(pagetable, va0);
}

代码

diff --git a/kernel/defs.h b/kernel/defs.h
index 3564db4..f5a9d8d 100644
--- a/kernel/defs.h
+++ b/kernel/defs.h
@@ -63,6 +63,7 @@ void            ramdiskrw(struct buf*);
 void*           kalloc(void);
 void            kfree(void *);
 void            kinit(void);
+void            incrfcount(void*);
 
 // log.c
 void            initlog(int, struct superblock*);
@@ -145,6 +146,8 @@ void            trapinit(void);
 void            trapinithart(void);
 extern struct spinlock tickslock;
 void            usertrapret(void);
+int             iscowpage(pagetable_t, uint64);
+int             cowfault(pagetable_t, uint64);
 
 // uart.c
 void            uartinit(void);
@@ -170,6 +173,7 @@ uint64          walkaddr(pagetable_t, uint64);
 int             copyout(pagetable_t, uint64, char *, uint64);
 int             copyin(pagetable_t, char *, uint64, uint64);
 int             copyinstr(pagetable_t, char *, uint64, uint64);
+pte_t*          walk(pagetable_t, uint64, int);
 
 // plic.c
 void            plicinit(void);
diff --git a/kernel/kalloc.c b/kernel/kalloc.c
index fa6a0ac..5872b85 100644
--- a/kernel/kalloc.c
+++ b/kernel/kalloc.c
@@ -14,6 +14,11 @@ void freerange(void *pa_start, void *pa_end);
 extern char end[]; // first address after kernel.
                    // defined by kernel.ld.
 
+#define PA2RFIDX(pa) ((((uint64)pa) - KERNBASE) / PGSIZE)
+
+int rfcount[(PHYSTOP - KERNBASE) / PGSIZE];
+struct spinlock rflock;
+
 struct run {
   struct run *next;
 };
@@ -27,6 +32,7 @@ void
 kinit()
 {
   initlock(&kmem.lock, "kmem");
+  initlock(&rflock, "rflock");
   freerange(end, (void*)PHYSTOP);
 }
 
@@ -51,15 +57,17 @@ kfree(void *pa)
   if(((uint64)pa % PGSIZE) != 0 || (char*)pa < end || (uint64)pa >= PHYSTOP)
     panic("kfree");
 
-  // Fill with junk to catch dangling refs.
-  memset(pa, 1, PGSIZE);
-
-  r = (struct run*)pa;
-
-  acquire(&kmem.lock);
-  r->next = kmem.freelist;
-  kmem.freelist = r;
-  release(&kmem.lock);
+  acquire(&rflock);
+  if(--rfcount[PA2RFIDX(pa)] <= 0){
+    memset(pa, 1, PGSIZE);
+    // Fill with junk to catch dangling refs.
+    r = (struct run*)pa;
+    acquire(&kmem.lock);
+    r->next = kmem.freelist;
+    kmem.freelist = r;
+    release(&kmem.lock);
+  }
+  release(&rflock);
 }
 
 // Allocate one 4096-byte page of physical memory.
@@ -76,7 +84,15 @@ kalloc(void)
     kmem.freelist = r->next;
   release(&kmem.lock);
 
-  if(r)
+  if(r) {
     memset((char*)r, 5, PGSIZE); // fill with junk
+    rfcount[PA2RFIDX(r)] = 1;
+  }
   return (void*)r;
 }
+
+void incrfcount(void* pa){
+  acquire(&rflock);
+  ++rfcount[PA2RFIDX(pa)];
+  release(&rflock);
+}
\ No newline at end of file
diff --git a/kernel/riscv.h b/kernel/riscv.h
index 1691faf..a6ba9e7 100644
--- a/kernel/riscv.h
+++ b/kernel/riscv.h
@@ -343,6 +343,8 @@ sfence_vma()
 #define PTE_W (1L << 2)
 #define PTE_X (1L << 3)
 #define PTE_U (1L << 4) // 1 -> user can access
+#define PTE_COW (1L << 8) // 1 -> is a COW page
+
 
 // shift a physical address to the right place for a PTE.
 #define PA2PTE(pa) ((((uint64)pa) >> 12) << 10)
diff --git a/kernel/trap.c b/kernel/trap.c
index a63249e..0fb7687 100644
--- a/kernel/trap.c
+++ b/kernel/trap.c
@@ -29,6 +29,42 @@ trapinithart(void)
   w_stvec((uint64)kernelvec);
 }
 
+
+int iscowpage(pagetable_t pgtbl, uint64 va) {
+  if (va >= MAXVA) return 0;
+  pte_t *pte = walk(pgtbl, va, 0);
+  if (pte == 0) return 0;
+  if ((*pte & PTE_V) == 0) return 0;
+  if ((*pte & PTE_U) == 0) return 0;
+  return *pte & PTE_COW;
+}
+
+int cowfault(pagetable_t pagetable, uint64 va) {
+  uint64 va0 = PGROUNDDOWN(va);
+  pte_t* pte;
+  if((pte = walk(pagetable, va0, 0)) == 0) return -1;
+  
+  uint64 flags = PTE_FLAGS(*pte);
+  uint64 pa0 = PTE2PA(*pte);
+
+  flags &= (~PTE_COW); // clear COW bit
+  flags |= PTE_W;      // set write bit
+
+  uint64 mem;
+  if ((mem = (uint64)kalloc()) == 0) return -1;
+  memmove((void *)mem, (void *)pa0, PGSIZE);
+
+  // remove old PTE
+  uvmunmap(pagetable, va0, 1, 1);
+  
+  // install new PTE
+  if(mappages(pagetable, va0, PGSIZE, mem, flags) < 0){
+    kfree((void *)mem);
+    return -1;
+  }
+  return 0;
+}
+
 //
 // handle an interrupt, exception, or system call from user space.
 // called from trampoline.S
@@ -67,7 +103,12 @@ usertrap(void)
     syscall();
   } else if((which_dev = devintr()) != 0){
     // ok
-  } else {
+  } else if (r_scause() == 15 && iscowpage(p->pagetable, r_stval())) {
+    if (cowfault(p->pagetable, r_stval()) < 0) {
+      p->killed = 1;
+    }
+  }
+  else {
     printf("usertrap(): unexpected scause %p pid=%d\n", r_scause(), p->pid);
     printf("            sepc=%p stval=%p\n", r_sepc(), r_stval());
     p->killed = 1;
diff --git a/kernel/vm.c b/kernel/vm.c
index d5a12a0..df0ddde 100644
--- a/kernel/vm.c
+++ b/kernel/vm.c
@@ -303,22 +303,20 @@ uvmcopy(pagetable_t old, pagetable_t new, uint64 sz)
   pte_t *pte;
   uint64 pa, i;
   uint flags;
-  char *mem;
 
   for(i = 0; i < sz; i += PGSIZE){
     if((pte = walk(old, i, 0)) == 0)
       panic("uvmcopy: pte should exist");
     if((*pte & PTE_V) == 0)
       panic("uvmcopy: page not present");
+    *pte &= ~PTE_W;   // set write bit
+    *pte |= PTE_COW;  // clear COW bit
     pa = PTE2PA(*pte);
     flags = PTE_FLAGS(*pte);
-    if((mem = kalloc()) == 0)
-      goto err;
-    memmove(mem, (char*)pa, PGSIZE);
-    if(mappages(new, i, PGSIZE, (uint64)mem, flags) != 0){
-      kfree(mem);
+    if(mappages(new, i, PGSIZE, pa, flags) != 0){
       goto err;
     }
+    incrfcount((void*)pa); // increment reference count to pa
   }
   return 0;
 
@@ -350,6 +348,9 @@ copyout(pagetable_t pagetable, uint64 dstva, char *src, uint64 len)
 
   while(len > 0){
     va0 = PGROUNDDOWN(dstva);
+    if (iscowpage(pagetable, va0)) {
+      cowfault(pagetable, va0);
+    }
     pa0 = walkaddr(pagetable, va0);
     if(pa0 == 0)
       return -1;
diff --git a/time.txt b/time.txt
new file mode 100644
index 0000000..209e3ef
--- /dev/null
+++ b/time.txt
@@ -0,0 +1 @@
+20